{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.contrib.layers import fully_connected\n",
    "from tensorflow.examples.tutorials.mnist import input_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n_steps = 28\n",
    "n_inputs = 28\n",
    "n_neurons = 150\n",
    "n_outputs = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Start from an empty default graph so the graph-building cells can be re-run\n",
    "# without colliding with variables created by a previous execution, and fix\n",
    "# the graph-level seed so weight initialization is repeatable.\n",
    "# NOTE(review): mini-batch order may still vary between runs -- confirm how\n",
    "# the dataset iterator is seeded if exact reproducibility matters.\n",
    "tf.reset_default_graph()\n",
    "tf.set_random_seed(42)\n",
    "\n",
    "# X: batch of input sequences, shape (batch, n_steps, n_inputs).\n",
    "# y: integer class id per sequence, shape (batch,).\n",
    "X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])\n",
    "y = tf.placeholder(tf.int32, [None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single recurrent layer, unrolled over the time axis of X by dynamic_rnn.\n",
    "# `outputs` holds the per-step outputs, `states` the final state of the cell.\n",
    "basic_cell = tf.contrib.rnn.BasicRNNCell(n_neurons)\n",
    "outputs, states = tf.nn.dynamic_rnn(cell=basic_cell, inputs=X, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Linear read-out (no activation) from the final RNN state to n_outputs logits.\n",
    "logits = fully_connected(inputs=states, num_outputs=n_outputs, activation_fn=None)\n",
    "# Per-example cross-entropy; labels are integer class ids, not one-hot vectors.\n",
    "xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Scalar objective: mean cross-entropy over the batch, minimized with Adam.\n",
    "loss = tf.reduce_mean(xentropy)\n",
    "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
    "training_op = optimizer.minimize(loss=loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A prediction counts as correct when the true label is the top-1 logit.\n",
    "correct = tf.nn.in_top_k(predictions=logits, targets=y, k=1)\n",
    "# Fraction of correct predictions in whatever batch is fed in.\n",
    "accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Op that initializes the graph's global variables; it is executed at the\n",
    "# start of the training session before any training step runs.\n",
    "init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Load MNIST from /tmp/data/ (input_data is imported in the first cell).\n",
    "mnist = input_data.read_data_sets(\"/tmp/data/\")\n",
    "\n",
    "# Shape the flat test images into (n_steps, n_inputs) sequences so they match\n",
    "# the X placeholder; the test labels are fed to y unchanged.\n",
    "X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))\n",
    "y_test = mnist.test.labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "n_epochs = 50\n",
    "batch_size = 150"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Train accuracy: 0.94  Test accuracy: 0.9127\n",
      "1 Train accuracy: 0.973333  Test accuracy: 0.954\n",
      "2 Train accuracy: 0.96  Test accuracy: 0.9549\n",
      "3 Train accuracy: 0.986667  Test accuracy: 0.9649\n",
      "4 Train accuracy: 0.966667  Test accuracy: 0.9692\n",
      "5 Train accuracy: 1.0  Test accuracy: 0.9708\n",
      "6 Train accuracy: 0.98  Test accuracy: 0.9663\n",
      "7 Train accuracy: 0.986667  Test accuracy: 0.967\n",
      "8 Train accuracy: 0.986667  Test accuracy: 0.969\n",
      "9 Train accuracy: 0.993333  Test accuracy: 0.9714\n",
      "10 Train accuracy: 0.98  Test accuracy: 0.9756\n",
      "11 Train accuracy: 1.0  Test accuracy: 0.971\n",
      "12 Train accuracy: 0.966667  Test accuracy: 0.968\n",
      "13 Train accuracy: 0.973333  Test accuracy: 0.9761\n",
      "14 Train accuracy: 0.973333  Test accuracy: 0.9779\n",
      "15 Train accuracy: 0.966667  Test accuracy: 0.9765\n",
      "16 Train accuracy: 0.993333  Test accuracy: 0.9775\n",
      "17 Train accuracy: 0.993333  Test accuracy: 0.9737\n",
      "18 Train accuracy: 0.993333  Test accuracy: 0.9762\n",
      "19 Train accuracy: 0.98  Test accuracy: 0.9754\n",
      "20 Train accuracy: 0.973333  Test accuracy: 0.9755\n",
      "21 Train accuracy: 1.0  Test accuracy: 0.9779\n",
      "22 Train accuracy: 0.98  Test accuracy: 0.9759\n",
      "23 Train accuracy: 0.993333  Test accuracy: 0.9794\n",
      "24 Train accuracy: 1.0  Test accuracy: 0.9775\n",
      "25 Train accuracy: 0.993333  Test accuracy: 0.9799\n",
      "26 Train accuracy: 1.0  Test accuracy: 0.9808\n",
      "27 Train accuracy: 0.993333  Test accuracy: 0.9754\n",
      "28 Train accuracy: 0.973333  Test accuracy: 0.977\n",
      "29 Train accuracy: 0.993333  Test accuracy: 0.9792\n",
      "30 Train accuracy: 1.0  Test accuracy: 0.9779\n",
      "31 Train accuracy: 0.986667  Test accuracy: 0.9727\n",
      "32 Train accuracy: 0.986667  Test accuracy: 0.9799\n",
      "33 Train accuracy: 0.993333  Test accuracy: 0.9782\n",
      "34 Train accuracy: 1.0  Test accuracy: 0.9781\n",
      "35 Train accuracy: 0.98  Test accuracy: 0.9786\n",
      "36 Train accuracy: 0.986667  Test accuracy: 0.979\n",
      "37 Train accuracy: 0.986667  Test accuracy: 0.9813\n",
      "38 Train accuracy: 0.986667  Test accuracy: 0.9702\n",
      "39 Train accuracy: 0.986667  Test accuracy: 0.9786\n",
      "40 Train accuracy: 0.986667  Test accuracy: 0.9779\n",
      "41 Train accuracy: 0.98  Test accuracy: 0.9755\n",
      "42 Train accuracy: 0.973333  Test accuracy: 0.9769\n",
      "43 Train accuracy: 0.98  Test accuracy: 0.9701\n",
      "44 Train accuracy: 0.986667  Test accuracy: 0.9768\n",
      "45 Train accuracy: 0.993333  Test accuracy: 0.9779\n",
      "46 Train accuracy: 0.993333  Test accuracy: 0.9779\n",
      "47 Train accuracy: 1.0  Test accuracy: 0.9788\n",
      "48 Train accuracy: 0.98  Test accuracy: 0.9784\n",
      "49 Train accuracy: 1.0  Test accuracy: 0.9782\n",
      "50 Train accuracy: 1.0  Test accuracy: 0.9781\n",
      "51 Train accuracy: 1.0  Test accuracy: 0.9732\n",
      "52 Train accuracy: 0.973333  Test accuracy: 0.9784\n",
      "53 Train accuracy: 0.993333  Test accuracy: 0.9792\n",
      "54 Train accuracy: 1.0  Test accuracy: 0.9824\n",
      "55 Train accuracy: 1.0  Test accuracy: 0.9758\n",
      "56 Train accuracy: 0.986667  Test accuracy: 0.9715\n",
      "57 Train accuracy: 0.993333  Test accuracy: 0.9727\n",
      "58 Train accuracy: 1.0  Test accuracy: 0.9812\n",
      "59 Train accuracy: 1.0  Test accuracy: 0.9801\n",
      "60 Train accuracy: 0.993333  Test accuracy: 0.9796\n",
      "61 Train accuracy: 0.993333  Test accuracy: 0.9748\n",
      "62 Train accuracy: 0.986667  Test accuracy: 0.9774\n",
      "63 Train accuracy: 1.0  Test accuracy: 0.9792\n",
      "64 Train accuracy: 0.993333  Test accuracy: 0.9716\n",
      "65 Train accuracy: 0.973333  Test accuracy: 0.9776\n",
      "66 Train accuracy: 0.98  Test accuracy: 0.9801\n",
      "67 Train accuracy: 0.993333  Test accuracy: 0.9755\n",
      "68 Train accuracy: 0.96  Test accuracy: 0.9717\n",
      "69 Train accuracy: 1.0  Test accuracy: 0.9789\n",
      "70 Train accuracy: 0.993333  Test accuracy: 0.9809\n",
      "71 Train accuracy: 1.0  Test accuracy: 0.9792\n",
      "72 Train accuracy: 1.0  Test accuracy: 0.9792\n",
      "73 Train accuracy: 0.986667  Test accuracy: 0.9821\n",
      "74 Train accuracy: 0.993333  Test accuracy: 0.9813\n",
      "75 Train accuracy: 0.98  Test accuracy: 0.9752\n",
      "76 Train accuracy: 1.0  Test accuracy: 0.976\n",
      "77 Train accuracy: 0.993333  Test accuracy: 0.9766\n",
      "78 Train accuracy: 1.0  Test accuracy: 0.9777\n",
      "79 Train accuracy: 1.0  Test accuracy: 0.9816\n",
      "80 Train accuracy: 0.986667  Test accuracy: 0.9794\n",
      "81 Train accuracy: 0.973333  Test accuracy: 0.9763\n",
      "82 Train accuracy: 0.986667  Test accuracy: 0.9758\n",
      "83 Train accuracy: 1.0  Test accuracy: 0.9801\n",
      "84 Train accuracy: 1.0  Test accuracy: 0.9822\n",
      "85 Train accuracy: 0.953333  Test accuracy: 0.9714\n",
      "86 Train accuracy: 0.993333  Test accuracy: 0.9812\n",
      "87 Train accuracy: 1.0  Test accuracy: 0.9803\n",
      "88 Train accuracy: 0.993333  Test accuracy: 0.9798\n",
      "89 Train accuracy: 1.0  Test accuracy: 0.9803\n",
      "90 Train accuracy: 0.986667  Test accuracy: 0.9754\n",
      "91 Train accuracy: 1.0  Test accuracy: 0.9794\n",
      "92 Train accuracy: 1.0  Test accuracy: 0.9821\n",
      "93 Train accuracy: 1.0  Test accuracy: 0.9824\n",
      "94 Train accuracy: 0.986667  Test accuracy: 0.9817\n",
      "95 Train accuracy: 0.993333  Test accuracy: 0.9803\n",
      "96 Train accuracy: 0.986667  Test accuracy: 0.9796\n",
      "97 Train accuracy: 0.986667  Test accuracy: 0.9781\n",
      "98 Train accuracy: 1.0  Test accuracy: 0.9784\n",
      "99 Train accuracy: 1.0  Test accuracy: 0.9805\n"
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    init.run()\n",
    "    for epoch in range(n_epochs):\n",
    "        for iteration in range(mnist.train.num_examples // batch_size):\n",
    "            X_batch, y_batch = mnist.train.next_batch(batch_size)\n",
    "            X_batch = X_batch.reshape((-1, n_steps, n_inputs))\n",
    "            sess.run(training_op, feed_dict={X: X_batch, y: y_batch})\n",
    "        acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
    "        acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})\n",
    "        print(epoch, \"Train accuracy:\", acc_train, \" Test accuracy:\", acc_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
